import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils.data import DataLoader
from data import CustomDataset
from matching_network import MatchingNetwork, matching_loss, compute_accuracy, process_batch_matching
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm


def set_seed(seed):
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), then puts cuDNN into
    deterministic, non-autotuning mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # Give up cuDNN's kernel autotuning in exchange for repeatable results.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def _run_training_epoch(model, loader, optimizer, device, num_classes, num_support, desc):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Each batch is converted to (predictions, query_labels) by
    ``process_batch_matching`` and optimised with ``matching_loss``.
    Returns 0 when ``loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0
    batch_count = 0

    for batch in tqdm(loader, desc=desc):
        predictions, query_labels = process_batch_matching(
            model, batch, device, num_classes, num_support
        )
        loss = matching_loss(predictions, query_labels)

        # Backpropagate and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        batch_count += 1

    return epoch_loss / batch_count if batch_count > 0 else 0


def _evaluate(model, loader, device, num_classes, num_support):
    """Return ``(mean_loss, mean_accuracy)`` of ``model`` over ``loader``, without gradients.

    The accuracy is the per-batch value returned by ``compute_accuracy``,
    averaged over batches and not weighted by batch size. It is presumably a
    fraction in [0, 1], since ``main`` multiplies it by 100 for display.
    Both values are 0 when ``loader`` is empty.
    """
    model.eval()
    total_loss = 0
    total_accuracy = 0
    batch_count = 0

    with torch.no_grad():
        for batch in loader:
            predictions, query_labels = process_batch_matching(
                model, batch, device, num_classes, num_support
            )
            total_loss += matching_loss(predictions, query_labels).item()
            total_accuracy += compute_accuracy(predictions, query_labels)
            batch_count += 1

    if batch_count == 0:
        return 0, 0
    return total_loss / batch_count, total_accuracy / batch_count


def _collect_predictions(model, loader, device, num_classes, num_support):
    """Run ``model`` over ``loader`` and gather the predicted and true query labels.

    Returns:
        ``(all_predictions, all_labels, correct, total)``. The first two are
        flat lists of per-query class indices, where a prediction is the
        argmax over dim 1 of the model output. ``correct`` and ``total``
        count query samples.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in loader:
            predictions, query_labels = process_batch_matching(
                model, batch, device, num_classes, num_support
            )
            _, predicted_classes = torch.max(predictions, dim=1)

            all_predictions.extend(predicted_classes.cpu().numpy())
            all_labels.extend(query_labels.cpu().numpy())

            correct += (predicted_classes == query_labels).sum().item()
            total += query_labels.size(0)

    return all_predictions, all_labels, correct, total


def main():
    """Train a Matching Network on ``traindata_5dB.mat`` and evaluate it.

    Steps:

    1. Seed the RNGs.
    2. Load the ``traindata`` array and build the model.
    3. Split the data into train and test sets.
    4. Train for ``num_epochs`` epochs. Each epoch is logged to
       ``matching_networks_training_log.txt``, and the weights with the best
       test accuracy are checkpointed to ``best_matching_network.pth``.
    5. Save the last-epoch weights, run a final test pass with them, and
       save a confusion-matrix plot.
    """
    # Seed every RNG for reproducibility.
    seed = 42
    set_seed(seed)

    # Device selection.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Episode parameters.
    num_classes = 10  # number of classes
    num_support = 10  # support samples per class
    num_query = 15    # query samples per class

    # Load the data.
    print("加载数据...")
    dictdata = loadmat("traindata_5dB.mat")
    dataload = dictdata['traindata']

    # Network hyper-parameters.
    input_size = 8      # input feature dimension
    hidden_size = 256   # hidden layer size
    output_size = 64    # embedding dimension

    print("创建Matching Network模型...")
    matching_net = MatchingNetwork(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        nhead=8,
        num_layers=2,
        seq_len=8,
        use_fce=False,  # FCE disabled for now to avoid dimension issues
        attention_type='cosine'  # cosine-similarity attention
    )
    matching_net.to(device)

    optimizer = optim.Adam(matching_net.parameters(), lr=0.001)

    print("准备数据集...")
    dataset = CustomDataset(dataload, noisy=False)

    # NOTE(review): test_size=0.9 keeps only 10% of the samples for training.
    # Confirm that this few-shot split is intentional.
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.9, random_state=42)

    # Each batch is sized to hold one full episode of support + query samples.
    # NOTE(review): with shuffle=True a batch is not guaranteed to hold exactly
    # num_support + num_query samples per class, and the last batch may be
    # smaller. Presumably process_batch_matching handles this; verify.
    batch_size = num_classes * (num_support + num_query)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # Sanity-check the tensor shapes of one batch.
    x, y = next(iter(train_loader))
    print(f"数据形状: {x.shape}")
    print(f"标签形状: {y.shape}")

    num_epochs = 50  # reduced epoch count for quick testing
    log_file = "matching_networks_training_log.txt"

    print("开始训练...")

    with open(log_file, 'w', encoding='utf-8') as f:
        # Real newlines. The original "\\n" wrote a literal backslash-n.
        f.write("Matching Networks 训练日志\n")
        f.write(f"模型参数: input_size={input_size}, hidden_size={hidden_size}, output_size={output_size}\n")
        f.write(f"训练参数: num_classes={num_classes}, num_support={num_support}, num_query={num_query}\n")
        f.write("=" * 50 + "\n")

        best_accuracy = 0.0

        for epoch in range(num_epochs):
            avg_loss = _run_training_epoch(
                matching_net, train_loader, optimizer, device,
                num_classes, num_support, desc=f'Epoch {epoch + 1}/{num_epochs}'
            )
            avg_test_loss, avg_test_accuracy = _evaluate(
                matching_net, test_loader, device, num_classes, num_support
            )
            avg_test_accuracy_percent = avg_test_accuracy * 100

            # Checkpoint the weights with the best test accuracy so far.
            if avg_test_accuracy > best_accuracy:
                best_accuracy = avg_test_accuracy
                torch.save(matching_net.state_dict(), 'best_matching_network.pth')

            print(f'Epoch {epoch + 1}/{num_epochs}:')
            print(f'  训练损失: {avg_loss:.4f}')
            print(f'  测试损失: {avg_test_loss:.4f}')
            print(f'  测试准确率: {avg_test_accuracy_percent:.2f}%')
            print(f'  最佳准确率: {best_accuracy * 100:.2f}%')
            print('-' * 50)

            f.write(f'Epoch {epoch + 1}: 训练损失={avg_loss:.4f}, 测试损失={avg_test_loss:.4f}, ')
            f.write(f'测试准确率={avg_test_accuracy_percent:.2f}%, 最佳准确率={best_accuracy * 100:.2f}%\n')

    # Save the last-epoch weights.
    torch.save(matching_net.state_dict(), 'final_matching_network.pth')
    print(f"训练日志已保存到: {log_file}")
    print(f"最佳模型已保存到: best_matching_network.pth")
    print(f"最终模型已保存到: final_matching_network.pth")

    # Final test and confusion matrix. This uses the last-epoch weights still
    # held by matching_net, not the best_matching_network.pth checkpoint.
    print("\n进行最终测试...")
    all_predictions, all_labels, final_correct, final_total = _collect_predictions(
        matching_net, test_loader, device, num_classes, num_support
    )

    # Guard against an empty test loader, which previously raised ZeroDivisionError.
    final_accuracy = 100 * final_correct / final_total if final_total > 0 else 0.0
    print(f'最终测试准确率: {final_accuracy:.2f}%')

    print("生成混淆矩阵...")
    cm = confusion_matrix(all_labels, all_predictions)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Matching Networks - 混淆矩阵')
    plt.savefig('matching_networks_confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("训练完成！")


if __name__ == "__main__":
    # Run the full training / evaluation pipeline only when executed as a script.
    main()
